//
// Copyright 2018 Ettus Research, A National Instruments Company
//
// SPDX-License-Identifier: LGPL-3.0-or-later
//
// Module: torus_2d_dor_router_single_sw
// Description: 
//   This module implements the router for a 2-dimensional (2d)
//   torus network that uses dimension order routing (dor) and has a
//   single underlying switch (single_sw). It uses AXI-Stream for all of its
//   links.
//   The torus topology, routing algorithms and the router architecture are
//   described in README.md in this directory.
// Parameters:
//   - WIDTH: Width of the AXI-Stream data bus
//   - DIM_SIZE: Number of routers along one dimension
//   - XB_ADDR_X: The X-coordinate of this router in the topology
//   - XB_ADDR_Y: The Y-coordinate of this router in the topology
//   - TERM_BUFF_SIZE: log2 of the ingress terminal buffer size (in words)
//   - XB_BUFF_SIZE: log2 of the ingress inter-router buffer size (in words)
//   - ROUTING_ALLOC: Algorithm to allocate routing paths between routers.
//     * WORMHOLE: Allocate route as soon as first word in pkt arrives
//     * CUT-THROUGH: Allocate route only after the full pkt arrives
//   - SWITCH_ALLOC: Algorithm to allocate the switch
//     * PRIO: Priority based. Priority: Y-dim > X-dim > Term
//     * ROUND-ROBIN: Round robin input port allocation
// Signals:
//   - *_axis_term_*: Terminal ports (master/slave)
//   - *_axis_xdim_*: Inter-router X-dim connections (master/slave)
//   - *_axis_ydim_*: Inter-router Y-dim connections (master/slave)
//

module torus_2d_dor_router_single_sw #(
  parameter                         WIDTH          = 64,
  parameter                         DIM_SIZE       = 4,
  parameter [$clog2(DIM_SIZE)-1:0]  XB_ADDR_X      = 0,
  parameter [$clog2(DIM_SIZE)-1:0]  XB_ADDR_Y      = 0,
  parameter                         TERM_BUFF_SIZE = 5,
  parameter                         XB_BUFF_SIZE   = 5,
  parameter                         ROUTING_ALLOC  = "WORMHOLE",
  parameter                         SWITCH_ALLOC   = "PRIO"
) (
  // Clocks and resets
  input  wire             clk,
  input  wire             reset,

  // Terminal connections
  input  wire [WIDTH-1:0] s_axis_term_tdata,
  input  wire             s_axis_term_tlast,
  input  wire             s_axis_term_tvalid,
  output wire             s_axis_term_tready,
  output wire [WIDTH-1:0] m_axis_term_tdata,
  output wire             m_axis_term_tlast,
  output wire             m_axis_term_tvalid,
  input  wire             m_axis_term_tready,

  // X-dimension inter-XB connections
  input  wire [WIDTH-1:0] s_axis_xdim_tdata,
  input  wire [0:0]       s_axis_xdim_tdest,
  input  wire             s_axis_xdim_tlast,
  input  wire             s_axis_xdim_tvalid,
  output wire             s_axis_xdim_tready,
  output wire [WIDTH-1:0] m_axis_xdim_tdata,
  output wire [0:0]       m_axis_xdim_tdest,
  output wire             m_axis_xdim_tlast,
  output wire             m_axis_xdim_tvalid,
  input  wire             m_axis_xdim_tready,

  // Y-dimension inter-XB connections
  input  wire [WIDTH-1:0] s_axis_ydim_tdata,
  input  wire [0:0]       s_axis_ydim_tdest,
  input  wire             s_axis_ydim_tlast,
  input  wire             s_axis_ydim_tvalid,
  output wire             s_axis_ydim_tready,
  output wire [WIDTH-1:0] m_axis_ydim_tdata,
  output wire [0:0]       m_axis_ydim_tdest,
  output wire             m_axis_ydim_tlast,
  output wire             m_axis_ydim_tvalid,
  input  wire             m_axis_ydim_tready
);

  //-------------------------------------------------
  // Routing and switch allocation functions
  //-------------------------------------------------

  // mesh_node_mapping.vh file contains the mapping between the node number
  // and its XY coordinates. It is autogenerated and defines the node_to_xdst
  // and node_to_ydst functions.
  `include "mesh_node_mapping.vh"

  // Switch port indices. The same encoding is used for the 2 LSBs of the
  // routing destination (compute_switch_tdest) and for the input port that
  // owns the switch (compute_switch_alloc). The order matches the port
  // concatenations on the axis_switch instance below: {ydim, xdim, term}.
  localparam [1:0] SW_DEST_TERM = 2'd0;
  localparam [1:0] SW_DEST_XDIM = 2'd1;
  localparam [1:0] SW_DEST_YDIM = 2'd2;
  localparam [1:0] SW_NUM_DESTS = 2'd3;

  // The compute_switch_tdest function is the destination selector
  // i.e. it will inspect the first word of a packet (the node number is
  // decoded from it by node_to_xdst/node_to_ydst) and determine the
  // destination of the packet.
  //   - header:  Word to decode. Only meaningful on the first word of a packet.
  //   - returns: {VC, DEST} where DEST is one of SW_DEST_* and VC is the
  //              virtual channel bit.
  // Dimension order: the X offset is resolved first; the Y dimension is only
  // selected once the X coordinate matches this router.
  function [2:0] compute_switch_tdest;
    input [WIDTH-1:0] header;
    reg [$clog2(DIM_SIZE)-1:0] xdst, ydst;
    // One bit wider than the coordinates so that the (zero-extended)
    // difference keeps its sign.
    reg signed [$clog2(DIM_SIZE):0] xdiff, ydiff;
  begin
    xdst  = node_to_xdst(header);
    ydst  = node_to_ydst(header);
    xdiff = xdst - XB_ADDR_X;
    ydiff = ydst - XB_ADDR_Y;
    // Routing logic
    // - MSB is the VC, 2 LSBs are the router destination
    // - Long journeys get VC = 1 to bypass local traffic
    //   (VC is set when the destination coordinate is smaller than ours)
    if (xdiff == 'd0 && ydiff == 'd0) begin
      compute_switch_tdest = {1'b0 /* VC don't care */,  SW_DEST_TERM};
    end else if (xdiff != 'd0) begin
      compute_switch_tdest = {(xdiff < 0), SW_DEST_XDIM};
    end else begin
      compute_switch_tdest = {(ydiff < 0), SW_DEST_YDIM};
    end
    //$display("xdst=%d, ydst=%d, xaddr=%d, yaddr=%d, dst=%d", xdst, ydst, XB_ADDR_X, XB_ADDR_Y, compute_switch_tdest);
  end
  endfunction

  // The compute_switch_alloc function is the switch allocation function
  // i.e. it chooses which input port reserves the switch for packet transfer.
  // After the switch is allocated, all other ports will be backpressured until
  // the packet finishes transferring.
  //   - pkt_waiting: One bit per input port (indexed by SW_DEST_*), set when
  //                  that port has data waiting.
  //   - last_alloc:  Port that owned the switch for the previous packet. Only
  //                  used by the round-robin policy.
  //   - returns:     Index (SW_DEST_*) of the port that gets the switch.
  // Note: any SWITCH_ALLOC value other than "PRIO" selects round-robin.
  function [1:0] compute_switch_alloc;
    input [2:0] pkt_waiting;
    input [1:0] last_alloc;
  begin
    // Zero or one requester: no arbitration needed
    if (pkt_waiting == 3'b000) begin
      compute_switch_alloc = SW_DEST_TERM;
    end else if (pkt_waiting == 3'b001) begin
      compute_switch_alloc = SW_DEST_TERM;
    end else if (pkt_waiting == 3'b010) begin
      compute_switch_alloc = SW_DEST_XDIM;
    end else if (pkt_waiting == 3'b100) begin
      compute_switch_alloc = SW_DEST_YDIM;
    end else begin
      // Two or more requesters: arbitrate
      if (SWITCH_ALLOC == "PRIO") begin
        // Priority: Y-dim > X-dim > Term
        if (pkt_waiting[SW_DEST_YDIM])
          compute_switch_alloc = SW_DEST_YDIM;
        else if (pkt_waiting[SW_DEST_XDIM])
          compute_switch_alloc = SW_DEST_XDIM;
        else
          compute_switch_alloc = SW_DEST_TERM;
      end else begin
        // Round-robin: try the ports following the previous owner first.
        // The 3-bit constants keep the sum from wrapping before the modulo.
        if (pkt_waiting[(last_alloc + 3'd1) % SW_NUM_DESTS])
          compute_switch_alloc = (last_alloc + 3'd1) % SW_NUM_DESTS;
        else if (pkt_waiting[(last_alloc + 3'd2) % SW_NUM_DESTS])
          compute_switch_alloc = (last_alloc + 3'd2) % SW_NUM_DESTS;
        else
          compute_switch_alloc = last_alloc;
      end
    end
  end
  endfunction

  //-------------------------------------------------
  // Ingress buffers
  //-------------------------------------------------
  wire [WIDTH-1:0] ydim_in_data , xdim_in_data , term_in_data ;
  wire [2:0]       ydim_in_dest , xdim_in_dest , term_in_dest ;
  wire             ydim_in_last , xdim_in_last , term_in_last ;
  wire             ydim_in_valid, xdim_in_valid, term_in_valid;
  wire             ydim_in_ready, xdim_in_ready, term_in_ready;

  // Data coming in from the terminal is gated until a full packet arrives
  // in order to minimize the switch allocation time per packet.
  axi_packet_gate #(
    .WIDTH(WIDTH), .SIZE(TERM_BUFF_SIZE)
  ) term_in_pkt_gate_i (
    .clk      (clk),
    .reset    (reset),
    .clear    (1'b0),
    .i_tdata  (s_axis_term_tdata),
    .i_tlast  (s_axis_term_tlast),
    .i_tvalid (s_axis_term_tvalid),
    .i_tready (s_axis_term_tready),
    .i_terror (1'b0),
    .o_tdata  (term_in_data),
    .o_tlast  (term_in_last),
    .o_tvalid (term_in_valid),
    .o_tready (term_in_ready)
  );
  // The destination is decoded combinationally from the word at the buffer
  // output, so it is only meaningful while that word is a packet header.
  assign term_in_dest = compute_switch_tdest(term_in_data);

  // The XY directions have buffers with 2 virtual channels to minimize the
  // possibility of a deadlock.
  axis_ingress_vc_buff #(
    .WIDTH(WIDTH), .NUM_VCS(2),
    .SIZE(XB_BUFF_SIZE),
    .ROUTING(ROUTING_ALLOC)
  ) xdim_in_vc_buf_i (
    .clk           (clk),
    .reset         (reset),
    .s_axis_tdata  (s_axis_xdim_tdata),
    .s_axis_tdest  (s_axis_xdim_tdest),
    .s_axis_tlast  (s_axis_xdim_tlast),
    .s_axis_tvalid (s_axis_xdim_tvalid),
    .s_axis_tready (s_axis_xdim_tready),
    .m_axis_tdata  (xdim_in_data),
    .m_axis_tlast  (xdim_in_last),
    .m_axis_tvalid (xdim_in_valid),
    .m_axis_tready (xdim_in_ready)
  );
  assign xdim_in_dest = compute_switch_tdest(xdim_in_data);

  axis_ingress_vc_buff #(
    .WIDTH(WIDTH), .NUM_VCS(2),
    .SIZE(XB_BUFF_SIZE),
    .ROUTING(ROUTING_ALLOC)
  ) ydim_in_vc_buf_i (
    .clk           (clk),
    .reset         (reset),
    .s_axis_tdata  (s_axis_ydim_tdata ),
    .s_axis_tdest  (s_axis_ydim_tdest ),
    .s_axis_tlast  (s_axis_ydim_tlast ),
    .s_axis_tvalid (s_axis_ydim_tvalid),
    .s_axis_tready (s_axis_ydim_tready),
    .m_axis_tdata  (ydim_in_data ),
    .m_axis_tlast  (ydim_in_last ),
    .m_axis_tvalid (ydim_in_valid),
    .m_axis_tready (ydim_in_ready)
  );
  assign ydim_in_dest = compute_switch_tdest(ydim_in_data);

  //-------------------------------------------------
  // Switch
  //-------------------------------------------------

  // Track the input packet state
  localparam [0:0] PKT_ST_HEAD = 1'b0;
  localparam [0:0] PKT_ST_BODY = 1'b1;
  reg [0:0] pkt_state = PKT_ST_HEAD;

  // Per-port view of the switch inputs, indexed by SW_DEST_*.
  wire [2:0] sw_in_valid_vec = {ydim_in_valid, xdim_in_valid, term_in_valid};
  wire [2:0] sw_in_ready_vec = {ydim_in_ready, xdim_in_ready, term_in_ready};
  wire [2:0] sw_in_last_vec  = {ydim_in_last , xdim_in_last , term_in_last };

  // A word moves into the switch only when valid and ready are asserted on
  // the *same* port. The packet tracking below must use this per-port
  // handshake: OR-ing valid, ready and last independently across ports lets
  // a packet that is merely waiting on another port (e.g. a single-word
  // packet presenting valid+last) look like the end of the packet that is
  // currently being transferred, which would release the switch mid-packet.
  wire [2:0] sw_in_xfer_vec = sw_in_valid_vec & sw_in_ready_vec;

  // sw_in_valid: at least one port has data waiting (drives the allocation)
  // sw_in_xfer:  a word is being accepted by the switch in this cycle
  // sw_in_last:  the word being accepted is the last word of its packet
  wire sw_in_valid = |sw_in_valid_vec;
  wire sw_in_xfer  = |sw_in_xfer_vec;
  wire sw_in_last  = |(sw_in_xfer_vec & sw_in_last_vec);

  always @(posedge clk) begin
    if (reset) begin
      pkt_state <= PKT_ST_HEAD;
    end else if (sw_in_xfer) begin
      // After the last word the next accepted word is a new header
      pkt_state <= sw_in_last ? PKT_ST_HEAD : PKT_ST_BODY;
    end
  end

  // The switch requires the allocation to stay valid until the
  // end of the packet. We also might need to keep the previous
  // packet's allocation to compute the current one
  wire [1:0] switch_alloc;
  reg  [1:0] prev_switch_alloc = SW_DEST_TERM;
  reg  [1:0] pkt_switch_alloc  = SW_DEST_TERM;

  always @(posedge clk) begin
    if (reset) begin
      prev_switch_alloc <= SW_DEST_TERM;
      pkt_switch_alloc <= SW_DEST_TERM;
    end else if (sw_in_xfer) begin
      // Latch the allocation when the header is accepted so that it is
      // held for the remaining words of the packet.
      if (pkt_state == PKT_ST_HEAD)
        pkt_switch_alloc <= switch_alloc;
      // Remember who owned the switch once the packet completes
      // (round-robin starts its search after this port).
      if (sw_in_last)
        prev_switch_alloc <= switch_alloc;
    end
  end

  // A new allocation is computed only while waiting for a header. It is
  // derived from the valids alone (not from ready), so it does not create a
  // combinational path from the switch's ready outputs back into the switch.
  assign switch_alloc = (sw_in_valid && pkt_state == PKT_ST_HEAD) ?
    compute_switch_alloc(sw_in_valid_vec, prev_switch_alloc) :
    pkt_switch_alloc;

  // The terminal output has no tdest port, so that slice of the switch's
  // m_axis_tdest is left unconnected to any module output.
  wire term_tdest_discard;
  axis_switch #(
    .DATA_W(WIDTH), .DEST_W(1), .IN_PORTS(3), .OUT_PORTS(3)
  ) switch_i (
    .clk           (clk),
    .reset         (reset),
    .s_axis_tdata  ({ydim_in_data , xdim_in_data , term_in_data }),
    .s_axis_tdest  ({ydim_in_dest , xdim_in_dest , term_in_dest }),
    .s_axis_tlast  ({ydim_in_last , xdim_in_last , term_in_last }),
    .s_axis_tvalid ({ydim_in_valid, xdim_in_valid, term_in_valid}),
    .s_axis_tready ({ydim_in_ready, xdim_in_ready, term_in_ready}),
    .s_axis_alloc  (switch_alloc),
    .m_axis_tdata  ({m_axis_ydim_tdata,  m_axis_xdim_tdata,  m_axis_term_tdata }),
    .m_axis_tdest  ({m_axis_ydim_tdest,  m_axis_xdim_tdest,  term_tdest_discard}),
    .m_axis_tlast  ({m_axis_ydim_tlast,  m_axis_xdim_tlast,  m_axis_term_tlast }),
    .m_axis_tvalid ({m_axis_ydim_tvalid, m_axis_xdim_tvalid, m_axis_term_tvalid}),
    .m_axis_tready ({m_axis_ydim_tready, m_axis_xdim_tready, m_axis_term_tready})
  );

endmodule

